{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "7cd67378-f73a-4e49-94f3-27a73f49723d",
   "metadata": {},
   "source": [
    "## Prepare raw graph files of AqSol dataset from their source SMILES"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "78fea400-29f9-4adb-ae8f-1886af43ed83",
   "metadata": {},
   "outputs": [],
   "source": [
    "from rdkit import Chem\n",
    "import networkx as nx\n",
    "import matplotlib.pyplot as plt\n",
    "import math\n",
    "import random\n",
    "import numpy as np\n",
    "from itertools import compress\n",
    "from rdkit.Chem.Scaffolds import MurckoScaffold\n",
    "from collections import defaultdict\n",
    "import os\n",
    "import pandas as pd\n",
    "import time\n",
    "import pickle\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
    "\n",
    "from rdkit import RDLogger\n",
    "RDLogger.DisableLog('rdApp.*')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f29fb220-1bc0-4e15-ab68-80e992573424",
   "metadata": {},
   "source": [
    "### Download csv source file"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "852c9490-4cbe-4e63-8aee-f7725086678e",
   "metadata": {},
   "outputs": [],
   "source": [
    "if not os.path.isfile('aqsol.csv'):\n",
    "    print('downloading..')\n",
    "    # The download link is present on this webpage: \n",
    "    # https://dataverse.harvard.edu/dataset.xhtml?persistentId=doi:10.7910/DVN/OVHAW8\n",
    "    !curl https://www.dropbox.com/s/26zoivsx3s1qr3q/curated-solubility-dataset.csv?dl=1 -o aqsol.csv -J -L -k\n",
    "    print('download complete')\n",
    "else:\n",
    "    print('File already downloaded')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b8c61e9f-d773-4707-8a63-ec4532fb1445",
   "metadata": {},
   "outputs": [],
   "source": [
    "aqsol_df = pd.read_csv('aqsol.csv') # read dataset\n",
    "smiles_list = list(aqsol_df['SMILES']) # get smiles strings from file\n",
    "labels_list = np.asarray(aqsol_df['Solubility']) # get solubility values from file\n",
    "mol_list = [Chem.MolFromSmiles(smiles) for smiles in smiles_list]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f23b08fd-16a1-4052-b6c1-6083eb47b2a0",
   "metadata": {},
   "source": [
    "### Scaffold Splitting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "204ed09b-3451-4920-9748-149da01384c6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Code snippet from\n",
    "# https://groups.google.com/g/open-graph-benchmark/c/0KML08gNVcM/m/99SeJ0zpAAAJ?pli=1\n",
    "\n",
    "def generate_scaffold(smiles, include_chirality=False):\n",
    "    \"\"\"\n",
    "    Obtain Bemis-Murcko scaffold from smiles\n",
    "    :param smiles:\n",
    "    :param include_chirality:\n",
    "    :return: smiles of scaffold\n",
    "    \"\"\"\n",
    "    scaffold = MurckoScaffold.MurckoScaffoldSmiles(\n",
    "        smiles=smiles, includeChirality=include_chirality)\n",
    "    return scaffold\n",
    "# # test generate scaffold\n",
    "# s = 'Cc1cc(Oc2nccc(CCC)c2)ccc1'\n",
    "# scaffold = generate_scaffold(s)\n",
    "# assert scaffold == 'c1ccc(Oc2ccccn2)cc1'\n",
    "\n",
    "def scaffold_split(smiles_list, frac_train=0.8, frac_valid=0.1, frac_test=0.1):\n",
    "    \"\"\"\n",
    "    Adapted from https://github.com/deepchem/deepchem/blob/master/deepchem/splits/splitters.py\n",
    "    Split dataset by Bemis-Murcko scaffolds. Deterministic split\n",
    "    :param smiles_list: list of smiles\n",
    "    :param frac_train:\n",
    "    :param frac_valid:\n",
    "    :param frac_test:\n",
    "    :return: list of train, valid, test indices corresponding to the\n",
    "    scaffold split\n",
    "    \"\"\"\n",
    "    np.testing.assert_almost_equal(frac_train + frac_valid + frac_test, 1.0)\n",
    "\n",
    "    # create dict of the form {scaffold_i: [idx1, idx....]}\n",
    "    all_scaffolds = {}\n",
    "    for i, smiles in enumerate(smiles_list):\n",
    "        scaffold = generate_scaffold(smiles, include_chirality=True)\n",
    "        if scaffold not in all_scaffolds:\n",
    "            all_scaffolds[scaffold] = [i]\n",
    "        else:\n",
    "            all_scaffolds[scaffold].append(i)\n",
    "\n",
    "    # sort from largest to smallest sets\n",
    "    all_scaffolds = {key: sorted(value) for key, value in all_scaffolds.items()}\n",
    "    all_scaffold_sets = [\n",
    "        scaffold_set for (scaffold, scaffold_set) in sorted(\n",
    "            all_scaffolds.items(), key=lambda x: (len(x[1]), x[1][0]), reverse=True)\n",
    "    ]\n",
    "\n",
    "    # get train, valid test indices\n",
    "    train_cutoff = frac_train * len(smiles_list)\n",
    "    valid_cutoff = (frac_train + frac_valid) * len(smiles_list)\n",
    "    train_idx, valid_idx, test_idx = [], [], []\n",
    "    for scaffold_set in all_scaffold_sets:\n",
    "        if len(train_idx) + len(scaffold_set) > train_cutoff:\n",
    "            if len(train_idx) + len(valid_idx) + len(scaffold_set) > valid_cutoff:\n",
    "                test_idx.extend(scaffold_set)\n",
    "            else:\n",
    "                valid_idx.extend(scaffold_set)\n",
    "        else:\n",
    "            train_idx.extend(scaffold_set)\n",
    "\n",
    "    assert len(set(train_idx).intersection(set(valid_idx))) == 0\n",
    "    assert len(set(train_idx).intersection(set(test_idx))) == 0\n",
    "    assert len(set(test_idx).intersection(set(valid_idx))) == 0\n",
    "\n",
    "    return train_idx, valid_idx, test_idx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "830939b1-6c27-4eff-8320-febb9f70f62e",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_idx, valid_idx, test_idx = scaffold_split(smiles_list)\n",
    "print(len(train_idx), len(valid_idx), len(test_idx))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c2e6a992-e378-4376-bfea-2e20ba555fa5",
   "metadata": {},
   "source": [
    "### Create Atom and Bond Dicts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c8c5b9df-8c2c-4a27-9932-7360d264eb33",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Code snippet from\n",
    "# https://github.com/xbresson/IPAM_Tutorial_2019/blob/master/04_molecule_regression/dictionaries.py\n",
    "class Dictionary(object):\n",
    "    \"\"\"\n",
    "    worddidx is a dictionary\n",
    "    idx2word is a list\n",
    "    \"\"\"\n",
    "    def __init__(self):\n",
    "        self.word2idx = {}\n",
    "        self.idx2word = []\n",
    "        self.word2num_occurence = {}\n",
    "        self.idx2num_occurence = []\n",
    "\n",
    "    def add_word(self, word):\n",
    "        if word not in self.word2idx:\n",
    "            # dictionaries\n",
    "            self.idx2word.append(word)\n",
    "            self.word2idx[word] = len(self.idx2word) - 1\n",
    "            # stats\n",
    "            self.idx2num_occurence.append(0)\n",
    "            self.word2num_occurence[word] = 0\n",
    "\n",
    "        # increase counters    \n",
    "        self.word2num_occurence[word]+=1\n",
    "        self.idx2num_occurence[  self.word2idx[word]  ] += 1\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.idx2word)\n",
    "\n",
    "\n",
    "def augment_dictionary(atom_dict, bond_dict, list_of_mol ):\n",
    "\n",
    "    \"\"\"\n",
    "    take a lists of rdkit molecules and use it to augment existing atom and bond dictionaries\n",
    "    \"\"\"\n",
    "    for idx,mol in enumerate(list_of_mol):\n",
    "\n",
    "        for atom in mol.GetAtoms():\n",
    "            atom_dict.add_word( atom.GetSymbol() )\n",
    "\n",
    "        for bond in mol.GetBonds():\n",
    "            bond_dict.add_word( str(bond.GetBondType()) )\n",
    "\n",
    "        # compute the number of edges of type 'None'\n",
    "        N=mol.GetNumAtoms()\n",
    "        if N>2:\n",
    "            E=N+math.factorial(N)/(math.factorial(2)*math.factorial(N-2)) # including self loop\n",
    "            num_NONE_bonds = E-mol.GetNumBonds()\n",
    "            bond_dict.word2num_occurence['NONE']+=num_NONE_bonds\n",
    "            bond_dict.idx2num_occurence[0]+=num_NONE_bonds\n",
    "\n",
    "\n",
    "def make_dictionary(list_of_mol):\n",
    "\n",
    "    \"\"\"\n",
    "    the list of smiles (train, val and test) and build atoms and bond dictionaries\n",
    "    \"\"\"\n",
    "    atom_dict=Dictionary()\n",
    "    bond_dict=Dictionary()\n",
    "    bond_dict.add_word('NONE')\n",
    "    print('making dictionary')\n",
    "    augment_dictionary(atom_dict, bond_dict, list_of_mol )\n",
    "    print('complete')\n",
    "    return atom_dict, bond_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "657f291b-8f82-450d-93df-134af7a02633",
   "metadata": {},
   "outputs": [],
   "source": [
    "atom_dict, bond_dict = make_dictionary(mol_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2d0a6b55-73b5-4538-a0a9-868e17139eb9",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(atom_dict.word2idx)\n",
    "print(atom_dict.word2num_occurence)\n",
    "print(bond_dict.word2idx)\n",
    "print(bond_dict.word2num_occurence)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "90973857-5f9d-4303-8038-b17642a6e0c3",
   "metadata": {},
   "source": [
    "### Create graph objects"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5718556a-9dcd-4e57-a926-ca3ee20298f4",
   "metadata": {},
   "outputs": [],
   "source": [
    "def mol_to_graph(mol, solubility):\n",
    "    \"\"\"\n",
    "        mol is a rdkit mol object\n",
    "    \"\"\"\n",
    "    no_bond_flag = False\n",
    "    node_feat = np.array([atom_dict.word2idx[atom.GetSymbol()] for atom in mol.GetAtoms()], dtype = np.int64)\n",
    "\n",
    "    if len(mol.GetBonds()) > 0: # mol has bonds\n",
    "        edges_list = []\n",
    "        edge_features_list = []\n",
    "        for bond in mol.GetBonds():\n",
    "            i = bond.GetBeginAtomIdx()\n",
    "            j = bond.GetEndAtomIdx()\n",
    "            edge_feature = bond_dict.word2idx[str(bond.GetBondType())]\n",
    "            # add edges in both directions\n",
    "            edges_list.append((i, j))\n",
    "            edge_features_list.append(edge_feature)\n",
    "            edges_list.append((j, i))\n",
    "            edge_features_list.append(edge_feature)\n",
    "\n",
    "        # edge_index: Graph connectivity in COO format with shape [2, num_edges]\n",
    "        edge_index = np.array(edges_list, dtype = np.int64).T\n",
    "\n",
    "        # edge_attr: Edge feature matrix with shape [num_edges, num_edge_features]\n",
    "        edge_feat = np.array(edge_features_list, dtype = np.int64)\n",
    "\n",
    "    else:   # mol has no bonds\n",
    "        no_bond_flag = True\n",
    "        # print(mol)\n",
    "        edge_index = np.empty((2, 0), dtype = np.int64)\n",
    "        edge_feat = np.empty((0, 1), dtype = np.int64)\n",
    "        \n",
    "    graph_object = (node_feat, edge_feat, edge_index, solubility)    \n",
    "    return graph_object, no_bond_flag"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d787cbc-ec04-4978-97d6-a1585e3fbab0",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_graph_objects(split_idx, mol_list, labels_list):\n",
    "    count_no_bond = 0\n",
    "    split_graph_objects =[]\n",
    "\n",
    "    for idx in split_idx:\n",
    "        graph_object, no_bond_flag = mol_to_graph(mol_list[idx], labels_list[idx])\n",
    "        split_graph_objects.append(graph_object)\n",
    "        len(graph_object[0])\n",
    "        if no_bond_flag:\n",
    "            #print(\"no bond graph, num nodes:\",len(graph_object[0]))\n",
    "            count_no_bond += 1\n",
    "    print(\"Total graphs with no bonds: \", count_no_bond)\n",
    "    return split_graph_objects"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f4dada9-fc0d-4ce5-9aee-ca8ef394bcf6",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"Train split..\")\n",
    "train_graph_objects = get_graph_objects(train_idx, mol_list, labels_list)\n",
    "print(\"Total graphs:\", len(train_graph_objects))\n",
    "print(\"Valid split..\")\n",
    "valid_graph_objects = get_graph_objects(valid_idx, mol_list, labels_list)\n",
    "print(\"Total graphs:\", len(valid_graph_objects))\n",
    "print(\"Train split..\")\n",
    "test_graph_objects = get_graph_objects(test_idx, mol_list, labels_list)\n",
    "print(\"Total graphs:\", len(test_graph_objects))\n",
    "print(\"Done\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9780c5d9-68d9-4ce6-aff1-1bb31973c589",
   "metadata": {},
   "source": [
    "### Save pkl files"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6bae84f2-ff5d-45a0-8ec6-1ee7c8f59dbc",
   "metadata": {},
   "outputs": [],
   "source": [
    "savedir = './asqol_graph_raw'\n",
    "if not os.path.exists(savedir):\n",
    "    os.makedirs(savedir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e19cf3fc-14d5-4df2-b8e6-5c9f0cfe401c",
   "metadata": {},
   "outputs": [],
   "source": [
    "start = time.time()\n",
    "with open(savedir+'/train.pickle','wb') as f:\n",
    "    pickle.dump(train_graph_objects,f)\n",
    "with open(savedir+'/val.pickle','wb') as f:\n",
    "    pickle.dump(valid_graph_objects,f)\n",
    "with open(savedir+'/test.pickle','wb') as f:\n",
    "    pickle.dump(test_graph_objects,f)\n",
    "with open(savedir+'/atom_dict.pickle','wb') as f:\n",
    "    pickle.dump(atom_dict,f)\n",
    "with open(savedir+'/bond_dict.pickle','wb') as f:\n",
    "    pickle.dump(bond_dict,f)\n",
    "print('Time (sec):',time.time() - start)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7e916b38-35b7-4e3e-8e2e-ab79c99dc322",
   "metadata": {},
   "outputs": [],
   "source": [
    "!zip -r aqsol_graph_raw.zip asqol_graph_raw"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5431b3bb-fe29-43ee-9e55-ff71b1916c12",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b205b79b-4474-49a7-b093-52541b8de42e",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
